import random
import gym
import numpy as np
import torch
import rl_utils


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG this script touches (PyTorch, NumPy, Python's `random`)
# so that runs are reproducible; each library keeps an independent RNG,
# so the order of these calls does not matter.
seedseed = 0
torch.manual_seed(seedseed)
np.random.seed(seedseed)
random.seed(seedseed)

# ---- hyper-parameters shared by every algorithm ----
# hidden_dim      : hidden size forwarded to the agent constructor
# discount_factor : reward discount forwarded to the agent constructor
# num_episodes    : number of training episodes run by the trainer
hidden_dim, discount_factor, num_episodes = 128, 0.9, 2000

# ---- environment ----
env_name = 'Pendulum-v1'
env = gym.make(env_name)
# Seed the environment with the same value used for random / NumPy / PyTorch.
env.reset(seed=seedseed)

# Sizes and action bounds are read from the first axis of each space.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_low = env.action_space.low[0]  # minimum action value
action_high = env.action_space.high[0]  # maximum action value

# Report the environment's dimensions, one "name = value" line each.
print('\n'.join([
    f'state_dim = {state_dim}',
    f'action_dim = {action_dim}',
    f'action_low = {action_low}',
    f'action_high = {action_high}',
]))


# Which on-policy algorithm to train: must be 'TRPO' or 'PPO'.
alg_name = 'PPO'

# Each branch sets its own hyper-parameters and imports only the module it
# needs, so the unused algorithm's module is never loaded.
if alg_name == 'TRPO':
    lmbda = 0.9  # presumably the GAE lambda (cf. para_GAE_lmbda for PPO) -- confirm
    critic_lr = 1e-2
    kl_constraint = 0.00005
    alpha = 0.5  # NOTE(review): meaning is defined by TRPOContinuous -- confirm
    from on_policy.alg_TRPO_Continuous import TRPOContinuous
    agent = TRPOContinuous(hidden_dim, env.observation_space, env.action_space,
                           lmbda, kl_constraint, alpha, critic_lr,
                           discount_factor, device)
elif alg_name == 'PPO':
    actor_lr = 1e-4
    critic_lr = 5e-3
    para_GAE_lmbda = 0.9
    epochs = 10
    para_PPO_clip = 0.2
    from on_policy.alg_PPO_Continuous import PPOContinuous
    agent = PPOContinuous(state_dim, hidden_dim, action_dim, actor_lr,
                          critic_lr, para_GAE_lmbda, epochs, para_PPO_clip,
                          discount_factor, device)
else:
    # Fail fast: otherwise `agent` would be undefined and the script would
    # only crash later, with a confusing NameError, when training starts.
    raise ValueError(
        f"Unsupported alg_name {alg_name!r}; expected 'TRPO' or 'PPO'")


# Train, evaluate, then render the agent. The environment is closed in a
# `finally` so its resources (e.g. a render window) are released even if one
# of these steps raises.
try:
    print('Training!!!!')
    return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)
    rl_utils.plot_results(return_list, env_name, alg_name,
                          string_train_test='Training',
                          moving_average_weight=9)

    print('Testing!!!!')
    return_list_test = rl_utils.test_agent(env, agent, num_episodes=50)
    rl_utils.plot_results(return_list_test, env_name, alg_name,
                          string_train_test='Testing',
                          moving_average_weight=3)

    print('Rendering!!!!')
    # NOTE(review): `env` was created by gym.make(env_name) without a
    # render_mode; on newer gym versions rendering may require one --
    # confirm against rl_utils.test_agent_render.
    rl_utils.test_agent_render(env, agent)
finally:
    env.close()







# time_start = time.perf_counter()  # record the start time

# time_end = time.perf_counter()  # record the end time
# time_sum = time_end - time_start  # elapsed time between the two readings, i.e. the program's run time, in seconds
# print('time = %f' %time_sum)